import torch
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Block, GPT2Attention
from collections import defaultdict
import math
from datasets import load_dataset
import os
from tqdm import tqdm
# To pin a specific GPU, set CUDA_VISIBLE_DEVICES before CUDA is first used, e.g.:
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Cache directory for Hugging Face datasets/models. Defaults to the original
# machine-specific path; set the HF_STORE_PATH environment variable to override it
# on other machines without editing the script.
store_path = os.environ.get("HF_STORE_PATH", '/media/dataset2/huggingface')
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='validation', cache_dir=store_path)

def concatenate_texts(examples):
    """Merge every string of a batch's 'text' column into a single row.

    Args:
        examples: Batch mapping whose 'text' entry is a list of strings.

    Returns:
        A mapping with a one-element 'text' list holding all input strings
        joined by single spaces.
    """
    merged = ' '.join(examples['text'])
    return {'text': [merged]}

# Collapse the loaded split into one string using concatenate_texts.
concat_text = dataset.map(concatenate_texts, batched=True, batch_size=-1)['text'][0]

model_name = "gpt2"  # or "gpt2-medium", "gpt2-large" etc.
tokenizer = GPT2TokenizerFast.from_pretrained(model_name, cache_dir=store_path)
# Token ids for the whole text stream; shape: (1, seq_len).
input_ids = tokenizer.encode(concat_text, return_tensors='pt')
print("Total tokens:", input_ids.size(1))

# Module.to() moves the parameters and returns the same module object.
model = GPT2LMHeadModel.from_pretrained(model_name, cache_dir=store_path).to(device)
model.eval()

# Helper functions for CKA computation
def linear_kernel(X, Y):
    """Return the linear (dot-product) Gram matrix between the rows of X and Y.

    Entry ``[i, j]`` is the inner product of row ``i`` of ``X`` with row ``j``
    of ``Y``.
    """
    return X @ Y.T

def rbf(X, Y, sigma=None):
    """Gaussian (RBF) kernel matrix built from Gram-matrix-derived distances.

    With ``G = X @ Y^T`` (which must be square), the distance term is
    ``d[i, j] = G[i, i] + G[j, j] - G[i, j] - G[j, i]``. This equals the squared
    Euclidean distance between rows when ``Y`` is ``X``.

    Args:
        X, Y: Feature matrices whose Gram matrix ``X @ Y^T`` is square.
        sigma: Kernel bandwidth. When None, it is set to the square root of the
            median of the non-zero distance entries.

    Returns:
        The elementwise kernel ``exp(-d / (2 * sigma**2))``.
    """
    gram = torch.matmul(X, Y.T)
    sq_norms = torch.diag(gram)
    half = sq_norms - gram
    dists = half + half.T
    if sigma is None:
        sigma = math.sqrt(torch.median(dists[dists != 0]))
    dists *= -0.5 / (sigma * sigma)
    return torch.exp(dists)

def HSIC(K, L):
    """Biased empirical HSIC between two kernel (Gram) matrices.

    Computes ``trace(K H L H) / (n - 1)**2``, where ``H = I - (1/n) * 11^T``
    is the centering matrix.

    Args:
        K, L: Square ``n x n`` kernel matrices sharing the same dtype and device.

    Returns:
        A 0-dim tensor holding the HSIC estimate.

    Raises:
        ValueError: If ``n < 2``, because the ``(n - 1)**2`` normaliser would be zero.
    """
    n = K.shape[0]
    if n < 2:
        raise ValueError(f"HSIC needs at least 2 samples, got {n}")
    # Build H in K's dtype/device. A default-float32 H would make torch.matmul
    # raise for float64/float16 kernels.
    H = torch.eye(n, dtype=K.dtype, device=K.device) - torch.full(
        (n, n), 1.0 / n, dtype=K.dtype, device=K.device
    )
    KH = torch.matmul(K, H)
    LH = torch.matmul(L, H)
    return 1.0 / ((n - 1) ** 2) * torch.trace(torch.matmul(KH, LH))

def CKA(X, Y, kernel=None):
    """Centered Kernel Alignment between two representations.

    Args:
        X, Y: Representation matrices whose kernel matrices have the same size
            (HSIC multiplies them against a shared ``n x n`` centering matrix).
        kernel: Callable ``kernel(A, B)`` returning a Gram matrix; defaults to
            ``linear_kernel`` when None.

    Returns:
        ``HSIC(K, L) / (sqrt(HSIC(K, K)) * sqrt(HSIC(L, L)))``, with
        ``K = kernel(X, X)`` and ``L = kernel(Y, Y)``.
    """
    if kernel is None:
        kernel = linear_kernel
    gram_x = kernel(X, X)
    gram_y = kernel(Y, Y)
    cross = HSIC(gram_x, gram_y)
    scale = torch.sqrt(HSIC(gram_x, gram_x)) * torch.sqrt(HSIC(gram_y, gram_y))
    return cross / scale

def hook_fn_capture(module, input, output, storage_dict, key, mode):
    """Forward-hook helper that records one detached activation tensor per call.

    Args:
        module: The hooked module. It is unused here and kept only to match the
            forward-hook signature.
        input: Positional inputs passed to the module. ``input[0]`` is captured
            when ``mode == "input"``.
        output: The module's return value. A tuple/list output has its first
            element captured; a bare tensor output is captured whole.
        storage_dict: Mapping of key -> list (e.g. ``defaultdict(list)``). The
            captured tensor is appended to ``storage_dict[key]``.
        key: Bucket under which the tensor is stored.
        mode: ``"input"`` or ``"output"``.

    Raises:
        ValueError: If ``mode`` is neither ``"input"`` nor ``"output"``.
    """
    if mode == "input":
        data = input[0]
    elif mode == "output":
        # Indexing a bare tensor with [0] would silently drop its leading
        # (batch) dimension, so only tuple/list outputs are unwrapped.
        data = output[0] if isinstance(output, (tuple, list)) else output
    else:
        raise ValueError("Mode must be 'input' or 'output'")
    # Keep the tensor on the device it was produced on. A hard-coded
    # .to('cuda') crashes on CPU-only machines.
    storage_dict[key].append(data.detach())


def analyze_similarity(layer_type, mode="input", stride=1024, num_chunks=1000):
    """Compute linear CKA between activations of consecutive ``layer_type`` modules.

    The function works in four steps:

    1. Attach a forward hook to every module of ``layer_type`` inside the global
       ``model``.
    2. Feed the global ``input_ids`` stream through the model in windows of at
       most ``model.config.n_positions`` tokens, advancing ``stride`` tokens per
       window.
    3. Within each window, record the hooked activations in execution order.
    4. Compute CKA (``linear_kernel`` by default) between each activation and the
       one captured immediately before it.

    Args:
        layer_type: Module class to hook, e.g. ``GPT2Attention`` or ``GPT2MLP``.
        mode: ``"input"`` to capture each module's first positional input, or
            ``"output"`` to capture its output (see ``hook_fn_capture``).
        stride: Number of new tokens each window advances by (previously
            hard-coded to 1024).
        num_chunks: Maximum number of windows to process (previously hard-coded
            to 1000). The loop stops early once ``input_ids`` is exhausted.

    Returns:
        ``{'per_sample_similarities': sims}``, where ``sims[k]`` is the CKA
        between the (k+1)-th and k-th captured activation, averaged over all
        processed windows.
    """
    model.eval()
    storage_dict = defaultdict(list)
    # One fixed bucket for every hooked module. A lambda reading
    # ``module.__class__.__name__`` would be late-bound to whatever module the
    # registration loop ended on, not to the hooked module itself.
    key = layer_type.__name__
    max_length = model.config.n_positions
    seq_len = input_ids.size(1)

    def _capture(m, inp, out):
        # Forward-hook adapter binding this call's storage, key and mode.
        hook_fn_capture(m, inp, out, storage_dict, key, mode)

    hooks = [
        module.register_forward_hook(_capture)
        for module in model.modules()
        if isinstance(module, layer_type)
    ]

    per_sample_similarities = []
    try:
        for i in tqdm(range(num_chunks)):
            window_start = i * stride
            end_loc = min(window_start + stride, seq_len)
            if end_loc <= window_start:
                break  # no new tokens left in the stream
            # Include up to max_length tokens of context ending at the window end.
            begin_loc = max(window_start + stride - max_length, 0)
            chunk = input_ids[:, begin_loc:end_loc].to(device)
            if chunk.size(1) < 2:
                # CKA/HSIC need at least two token rows; skip degenerate windows.
                continue

            # The forward pass runs only for its hook side effects. Logits and
            # loss are never used, so no labels are passed.
            with torch.no_grad():
                model(chunk)

            activations = storage_dict[key]
            similarities = []
            for prev, cur in zip(activations, activations[1:]):
                # Flatten all leading dims into rows -> (tokens, features). Unlike
                # .squeeze(), this always yields a 2-D matrix.
                cur_2d = cur.reshape(-1, cur.shape[-1])
                prev_2d = prev.reshape(-1, prev.shape[-1])
                similarities.append(CKA(cur_2d, prev_2d).item())
            per_sample_similarities.append(similarities)

            # Reset the bucket for the next window.
            activations.clear()
    finally:
        # Always detach the hooks, even on error, so they do not accumulate on
        # the shared global model across calls.
        for hook in hooks:
            hook.remove()

    # Average each consecutive-pair similarity across all windows.
    averaged_similarities = [
        sum(pair_sims) / len(pair_sims)
        for pair_sims in zip(*per_sample_similarities)
    ]
    return {'per_sample_similarities': averaged_similarities}

# Running analysis: each entry maps a result name to (module class, hook mode).
analysis_plan = {
    # 'attention_input': (GPT2Attention, "input"),
    'attention_output': (GPT2Attention, "output"),
    'mlp_input': (GPT2MLP, "input"),
    'mlp_output': (GPT2MLP, "output"),
}
results = {
    name: analyze_similarity(layer_cls, mode=hook_mode)
    for name, (layer_cls, hook_mode) in analysis_plan.items()
}

# Print results
for name, outcome in results.items():
    print(f"\nResults for {name}:")
    print(f"Per-Sample CKA similarities: {outcome['per_sample_similarities']}")
